import Optimizer
import random

# This optimizer performs a generational genetic optimization over a given
# number of generations, using uniform crossover.
# The generation size MUST be an even number: children are produced in pairs.
class GeneticAlgorithm(Optimizer.Optimizer):
    """Generational genetic algorithm over fixed-width bit-string states.

    States are non-negative integers interpreted as ``num_bits``-bit strings.
    Each generation is bred from the previous one by roulette-wheel parent
    selection, uniform crossover and per-bit mutation. Fitness values are read
    from ``self.explored_states[state]``; that mapping is not defined in this
    class and is presumably filled with state -> fitness entries by the
    ``Optimizer`` base class (verify there).

    ``generation_size`` should be even: children are produced in pairs, so an
    odd size yields one extra child per generation.
    """

    # mutation_prob: probability that each bit of a child is flipped
    # num_generations: number of getNextStates() calls before isFinished()
    #   returns True
    # generation_size: number of states per generation (should be even)
    # num_bits: width of every state in bits
    # number_top_states, characterizer, verbose: forwarded to Optimizer
    # initial_population: optional seed states; None, [] or a list whose first
    #   element is None all mean "start from uniformly random states"
    def __init__(self, mutation_prob, num_generations, generation_size, num_bits,
                 number_top_states, characterizer, initial_population=None,
                 verbose=False):
        super(GeneticAlgorithm, self).__init__(number_top_states, characterizer, verbose)
        self.verbose = verbose

        self.num_generations = num_generations
        self.generation_size = generation_size
        self.maximum_state = 2 ** num_bits - 1
        self.num_bits = num_bits

        self.initial_population = initial_population

        self.generation_number = 0
        self.population = []

        self.p_mutate = mutation_prob

############### INTERFACE FUNCTIONS ###############

    def isFinished(self):
        """Return True once ``num_generations`` generations have been produced."""
        if self.verbose:
            print("Checking if Finished")
        return self.generation_number >= self.num_generations

    def getNextStates(self):
        """Advance one generation and return its states.

        The first call builds the initial population; later calls breed the
        next generation from the current one via selection, recombination and
        mutation. Every call increments ``generation_number``.
        """
        if self.verbose:
            print("Getting Next States")
        if not self.population:
            print("Initializing Population...")
            self.population = self.getInitialPopulation()
        else:
            print("Selecting Parents...")
            parents = self.selection()
            print("Performing Recombination...")
            children = self.recombination(parents)
            print("Mutating Children...")
            self.population = self.mutation(children)
        self.generation_number += 1
        return self.population

############### Child Class Helper Functions ###############

    def getInitialPopulation(self):
        """Return the first generation of ``generation_size`` states.

        Without seeds (``initial_population`` is None, empty, or starts with
        None) the states are drawn uniformly from [0, maximum_state]. With
        seeds, the seed list is truncated or padded to ``generation_size``;
        each padding state is a randomly chosen existing member with up to
        half of its bits flipped.
        """
        seeds = self.initial_population
        if seeds is None or len(seeds) == 0 or seeds[0] is None:
            return [random.randint(0, self.maximum_state)
                    for _ in range(self.generation_size)]
        # Copy so that padding never mutates the caller's list.
        population = list(seeds[:self.generation_size])
        while len(population) < self.generation_size:
            population.append(self._perturb_state(random.choice(population)))
        return population

    def selection(self, method='ranked'):
        """Choose parent pairs for the next generation by roulette-wheel sampling.

        method='ranked'  -> the i-th fittest member (0-based) has weight
                            len(population) - i.
        method='fitness' -> each member has weight fitness - min_fitness + 1,
                            so every weight is >= 1 even for negative fitness.

        Every member of the current population must already have an entry in
        ``self.explored_states`` (its fitness).

        Returns ceil(generation_size / 2) ``(state_a, state_b)`` tuples. The
        two states of a pair differ whenever the population contains at least
        two distinct states.

        Raises ValueError for an unknown ``method``.
        """
        population_fitness = [self.explored_states[state] for state in self.population]
        # Most fit first (ties broken by state value).
        ordered = sorted(zip(population_fitness, self.population), reverse=True)
        if method == 'ranked':
            size = len(ordered)
            weighted = [(state, size - rank) for rank, (_, state) in enumerate(ordered)]
        elif method == 'fitness':
            min_fitness = min(population_fitness)
            # Shift so the least fit member still gets a positive weight.
            weighted = [(state, fitness - min_fitness + 1) for fitness, state in ordered]
        else:
            raise ValueError("unknown selection method: %r" % (method,))
        total_weight = sum(weight for _, weight in weighted)
        # With a single distinct state, insisting on parent_a != parent_b
        # would redraw forever.
        can_differ = len(set(state for state, _ in weighted)) > 1
        parents = []
        for _ in range((self.generation_size + 1) // 2):
            parent_a = self._roulette_draw(weighted, total_weight)
            parent_b = self._roulette_draw(weighted, total_weight)
            while can_differ and parent_b == parent_a:
                parent_b = self._roulette_draw(weighted, total_weight)
            parents.append((parent_a, parent_b))
        return parents

    def recombination(self, parents):
        """Produce two children per parent pair by uniform crossover.

        For each bit position a fair coin decides which child receives parent
        A's bit; the other child receives parent B's bit. Children are
        returned as '0b'-prefixed strings of exactly ``num_bits`` digits, the
        format ``mutation`` expects.
        """
        children = []
        for parent_a, parent_b in parents:
            bits_a = self._to_bits(parent_a)[:self.num_bits]
            bits_b = self._to_bits(parent_b)[:self.num_bits]
            kids = [['0b'], ['0b']]
            for bit_a, bit_b in zip(bits_a, bits_b):
                kid_a_idx = random.randint(0, 1)
                kids[kid_a_idx].append(bit_a)
                kids[1 - kid_a_idx].append(bit_b)
            children.extend(''.join(kid) for kid in kids)
        return children

    def mutation(self, children):
        """Mutate each child bit string and return the new generation as ints.

        Each bit is flipped independently with probability ``p_mutate``. A
        mutant that is already in ``self.explored_states`` is re-drawn, which
        forces extra mutations when few unexplored states remain. Re-drawing
        is skipped when it could never succeed: when the mutation is
        deterministic (``p_mutate`` <= 0 or >= 1), or when every state in
        [0, maximum_state] has already been explored.
        """
        can_retry = 0 < self.p_mutate < 1 and not self._search_space_exhausted()
        population = []
        for kid in children:
            mutant_kid = self._mutate_bits(kid)
            while can_retry and mutant_kid in self.explored_states:
                mutant_kid = self._mutate_bits(kid)
            population.append(mutant_kid)
        return population

############### Private helpers ###############

    def _to_bits(self, state):
        """Return ``state`` as binary digits, left-padded with zeros to ``num_bits``."""
        return format(state, 'b').zfill(self.num_bits)

    @staticmethod
    def _flip_bit(bit):
        """Return the opposite binary digit character."""
        return '1' if bit == '0' else '0'

    def _perturb_state(self, state):
        """Return ``state`` with each of ``num_bits // 2`` random positions flipped with probability 1/2."""
        bits = list(self._to_bits(state))
        for bit_idx in random.sample(range(self.num_bits), self.num_bits // 2):
            if random.random() >= 0.5:
                bits[bit_idx] = self._flip_bit(bits[bit_idx])
        return int(''.join(bits), 2)

    def _mutate_bits(self, kid):
        """Return the '0b'-prefixed bit string ``kid`` as an int, each bit flipped with probability ``p_mutate``."""
        bits = [self._flip_bit(bit) if random.random() < self.p_mutate else bit
                for bit in kid[2:2 + self.num_bits]]
        return int(''.join(bits), 2)

    def _search_space_exhausted(self):
        """Return True when every state in [0, maximum_state] has been explored."""
        explored_in_range = sum(1 for state in self.explored_states.keys()
                                if 0 <= state <= self.maximum_state)
        return explored_in_range > self.maximum_state

    @staticmethod
    def _roulette_draw(weighted, total_weight):
        """Return the state whose cumulative-weight interval contains a uniform draw.

        ``weighted`` is a list of ``(state, weight)`` pairs with positive
        weights summing to ``total_weight``.
        """
        lotto_number = random.uniform(0, total_weight)
        cumulative = 0
        for state, weight in weighted:
            cumulative += weight
            if lotto_number < cumulative:
                return state
        # random.uniform may return the upper bound itself; that belongs to
        # the last interval.
        return weighted[-1][0]
